### Function for simulating networks ###
## Inputs
#'  BLK: integer; number of latent blocks
#'  NODE: int.; number of nodes in net
#'  STATE: int.; number of latent states in HMM
#'  TIME: int.; number of time periods in net. history (must be >= STATE)
#'  DIRECTED: bool.; is network directed?
#'  N_PRED: vector; two element vector, nr. monadic and dyadic predictors
#'  B: matrix; BLK x BLK blockmodel matrix (entries are on the log-odds scale)
#'  sVec: vec.; TIME-long vector of HMM states
#'  beta_arr: list (one element per state) of BLK x (N_PRED[1] + 1) coefficient
#'            matrices; the first column multiplies the intercept row of X
#'  gamma_vec: vec.; N_PRED[2]-long vector of dyadic coefficients
#'  X: mat; (N_PRED[1] + 1) x (NODE * TIME) matrix of node-level predictors,
#'     first row is an intercept, columns are named "<node>_<time>"
#'  Z: mat or data frame; one row per dyad and time period, columns are
#'     sender id, receiver id, time period, then N_PRED[2] dyad-level predictors
#'
#'  Any of B, sVec, beta_arr, gamma_vec, X, Z left NULL is simulated.
#'
## Output
#' List, with elements in the input list, plus
#' Y: binary vector of edges, one per row of Z
#' dyad.data: Z as a data frame plus Y and the sampled sender/receiver blocks
#' monad.data: t(X) as a data frame plus node, year and membership columns
#' pi_vecs: matrix of true mixed memberships (rows "<node>_<time>")
#' theta: true probability of each edge
#' theta_full: list with one element per entry of B (column-major order),
#'             each holding the edge probabilities under that block pair

NetSim2 <- function(BLK = 4,
                    NODE = 20,
                    STATE = 1,
                    TIME = 1,
                    DIRECTED = TRUE,
                    N_PRED = c(2, 2),
                    B = NULL,
                    sVec = NULL,
                    beta_arr = NULL,
                    gamma_vec = NULL,
                    X = NULL,
                    Z = NULL){

  stopifnot(TIME >= STATE, length(N_PRED) >= 2)

  ## ---- local helpers -------------------------------------------------

  ## Draw a single category index from a probability vector
  draw_category <- function(prob){
    which.max(rmultinom(1, 1, prob))
  }

  ## Draw n Dirichlet(alpha) vectors (one per row) by normalising Gamma
  ## variates; rows come out NaN when every Gamma draw is 0 or Inf
  draw_dirichlet <- function(n, alpha){
    k <- length(alpha)
    g <- matrix(rgamma(k * n, alpha), ncol = k, byrow = TRUE)
    g / rowSums(g)
  }

  ## Gaussian random walk of length TIME starting at `start` (vector or
  ## matrix); returns a list with one element per time period
  random_walk <- function(start){
    path <- vector("list", TIME)
    path[[1]] <- start
    for(tt in seq_len(TIME - 1) + 1){
      path[[tt]] <- path[[tt - 1]] + rnorm(length(start), 0, 1)
    }
    path
  }

  ## ---- Sample HMM states ---------------------------------------------
  if(is.null(sVec)){
    s1v <- rnorm(STATE)
    s1 <- draw_category(exp(s1v) / sum(exp(s1v)))
    if(STATE > 1){
      ## Row s of A is the transition distribution out of state s
      A <- draw_dirichlet(STATE, rep(0.25, STATE))
      sVec <- c(s1, rep(NA, TIME - 1))
      for(i in seq_len(TIME - 1) + 1){
        sVec[i] <- draw_category(A[sVec[i - 1], ])
      }
    } else {
      sVec <- rep(s1, TIME)
    }
  }

  ## ---- Sample monadic coefficients -----------------------------------
  ## Built with matrix() so the result stays a matrix even when BLK == 1
  if(is.null(beta_arr)){
    beta_arr <- lapply(seq_len(STATE), function(s){
      matrix(rnorm(BLK * (N_PRED[1] + 1), 0, 3), nrow = BLK)
    })
  }

  ## ---- Sample (time-dependent) monadic predictors --------------------
  if(is.null(X)){
    X <- matrix(1, nrow = 1, ncol = NODE * TIME)   # intercept row
    if(N_PRED[1] > 0){
      x_start <- matrix(rnorm(N_PRED[1] * NODE, 0, 2),
                        nrow = N_PRED[1], ncol = NODE)
      ## Periods are bound column-wise so each predictor keeps its own row;
      ## concatenating with c() would flatten them and rbind() would then
      ## silently truncate to a single row.
      X <- rbind(X, do.call(cbind, random_walk(x_start)))
    }
    colnames(X) <- paste(rep(1:NODE, TIME), rep(1:TIME, each = NODE), sep = "_")
  }

  ## ---- Sample mixed-membership vectors -------------------------------
  pi_list <- lapply(seq_len(NODE), function(node){
    node_cols <- paste(node, 1:TIME, sep = "_")
    Xsub <- X[, node_cols, drop = FALSE]
    memb <- lapply(seq_len(TIME), function(tt){
      ebeta <- exp(as.vector(beta_arr[[sVec[tt]]] %*% Xsub[, tt]))
      pi_vec <- as.vector(draw_dirichlet(1, ebeta))
      if(anyNA(pi_vec)){
        ## Degenerate draw: put all mass on the block with largest weight
        pi_vec <- as.numeric(seq_along(ebeta) == which.max(ebeta))
      }
      pi_vec
    })
    memb <- do.call(cbind, memb)
    colnames(memb) <- node_cols
    memb
  })
  pi.mat <- t(do.call(cbind, pi_list))   # rows "<node>_<time>", cols = blocks

  ## ---- Sample dyadic coefficients ------------------------------------
  if(is.null(gamma_vec)){
    gamma_vec <- rnorm(N_PRED[2], -1, 1)
  }

  ## ---- Sample dyadic predictors --------------------------------------
  if(is.null(Z)){
    ## Directed: every ordered pair, self-pairs included (data frame);
    ## undirected: every unordered pair (matrix)
    if(DIRECTED){
      dy <- expand.grid(1:NODE, 1:NODE)
    } else {
      dy <- t(combn(1:NODE, 2))
    }
    n_dyad <- nrow(dy)
    Z <- do.call(rbind, replicate(TIME, dy, simplify = FALSE))
    Z <- cbind(Z, rep(1:TIME, each = n_dyad))
    pred_names <- character(0)
    if(N_PRED[2] > 0){
      dyad_pred <- matrix(NA_real_, nrow(Z), N_PRED[2])
      for(d in seq_len(N_PRED[2])){
        dyad_pred[, d] <- unlist(random_walk(rnorm(n_dyad, 0, 2)))
      }
      Z <- cbind(Z, dyad_pred)
      pred_names <- paste0("V", seq_len(N_PRED[2]))
    }
    colnames(Z) <- c("node1", "node2", "year", pred_names)
  }

  ## ---- Sample dyad-specific block memberships ------------------------
  ## Columns 1, 2, 3 of Z are sender, receiver and time period
  z <- apply(Z, 1, function(dyad){
    draw_category(pi.mat[paste(dyad[1], dyad[3], sep = "_"), ])
  })
  w <- apply(Z, 1, function(dyad){
    draw_category(pi.mat[paste(dyad[2], dyad[3], sep = "_"), ])
  })

  ## ---- Sample blockmodel ---------------------------------------------
  if(is.null(B)){
    B <- matrix(rnorm(BLK * BLK), BLK, BLK)
    if(!DIRECTED){
      ## Mirror the upper triangle so B stays an ordinary base matrix
      B[lower.tri(B)] <- t(B)[lower.tri(B)]
    }
  }

  ## ---- Sample edges --------------------------------------------------
  z_pred <- as.matrix(Z[, -c(1:3), drop = FALSE])
  if(ncol(z_pred) > 0){
    dgam <- z_pred %*% as.matrix(gamma_vec)
  } else {
    dgam <- matrix(0, nrow(Z), 1)
  }
  ## element 1 for B[1,1], 2 for B[2,1], etc.
  theta <- lapply(B, function(b){plogis(b + dgam)})
  ## Probability for each dyad under its sampled (sender, receiver) blocks
  prob.edge <- as.vector(plogis(B[cbind(z, w)] + dgam))
  names(prob.edge) <- names(z)
  Y <- rbinom(length(prob.edge), size = 1, prob = prob.edge)
  names(Y) <- names(prob.edge)

  ## ---- Form dataframes and matrices for return object ----------------
  dyad.data <- as.data.frame(Z)
  dyad.data$Y <- Y
  dyad.data$grp1 <- z
  dyad.data$grp2 <- w
  monad.data <- as.data.frame(t(X))
  id_parts <- strsplit(rownames(monad.data), "_")
  monad.data$node <- as.numeric(vapply(id_parts, `[[`, character(1), 1))
  monad.data$year <- as.numeric(vapply(id_parts, `[[`, character(1), 2))
  pi.mat2 <- pi.mat[colnames(X), , drop = FALSE]
  for(i in seq_len(BLK)){
    monad.data[, paste0("pi", i)] <- pi.mat2[, i]
  }

  ## NOTE(review): DYAD_PRED and MONAD_PRED both carry the full N_PRED
  ## vector; left as-is so existing callers keep seeing the same shape.
  list(BLK = BLK
       ,NODE = NODE
       ,STATE = STATE
       ,TIME = TIME
       ,DYAD_PRED = N_PRED
       ,MONAD_PRED = N_PRED
       ,DIRECTED = DIRECTED
       ,Y = Y
       ,sVec = sVec
       ,B = B
       ,beta_arr = beta_arr
       ,gamma_mat = gamma_vec
       ,dyad.data = dyad.data
       ,monad.data = monad.data
       ,X = X
       ,Z = Z
       ,pi_vecs = pi.mat2
       ,theta = prob.edge
       ,theta_full = theta
  )
}
                    
      
